{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.deepar"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DeepAR"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The DeepAR model produces probabilistic forecasts based on an autoregressive recurrent neural network optimized on panel data using cross-learning. DeepAR obtains its forecast distribution uses a Markov Chain Monte Carlo sampler with the following conditional probability:\n",
    "$$\\mathbb{P}(\\mathbf{y}_{[t+1:t+H]}|\\;\\mathbf{y}_{[:t]},\\; \\mathbf{x}^{(f)}_{[:t+H]},\\; \\mathbf{x}^{(s)})$$\n",
    "\n",
    "where $\\mathbf{x}^{(s)}$ are static exogenous inputs, $\\mathbf{x}^{(f)}_{[:t+H]}$ are future exogenous available at the time of the prediction.\n",
    "The predictions are obtained by transforming the hidden states $\\mathbf{h}_{t}$ into predictive distribution parameters $\\theta_{t}$, and then generating samples $\\mathbf{\\hat{y}}_{[t+1:t+H]}$ through Monte Carlo sampling trajectories.\n",
    "\n",
    "\\begin{align}\n",
    "\\mathbf{h}_{t} &= \\textrm{RNN}([\\mathbf{y}_{t},\\mathbf{x}^{(f)}_{t+1},\\mathbf{x}^{(s)}], \\mathbf{h}_{t-1})\\\\\n",
    "\\mathbf{\\theta}_{t}&=\\textrm{Linear}(\\mathbf{h}_{t}) \\\\\n",
    "\\hat{y}_{t+1}&=\\textrm{sample}(\\;\\mathrm{P}(y_{t+1}\\;|\\;\\mathbf{\\theta}_{t})\\;)\n",
    "\\end{align}\n",
    "\n",
    "**References**<br>\n",
    "- [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020). \"DeepAR: Probabilistic forecasting with autoregressive recurrent networks\". International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207019301888)<br>\n",
    "- [Alexander Alexandrov et. al (2020). \"GluonTS: Probabilistic and Neural Time Series Modeling in Python\". Journal of Machine Learning Research.](https://www.jmlr.org/papers/v21/19-820.html)<br>\n",
    "\n",
    "\n",
    ":::{.callout-warning collapse=\"false\"}\n",
    "#### Exogenous Variables, Losses, and Parameters Availability\n",
    "\n",
    "Given the sampling procedure during inference, DeepAR only supports `DistributionLoss` as training loss.\n",
    "\n",
    "Note that DeepAR generates a non-parametric forecast distribution using Monte Carlo. We use this sampling procedure also during validation to make it closer to the inference procedure. Therefore, only the `MQLoss` is available for validation.\n",
    "\n",
    "Aditionally, Monte Carlo implies that historic exogenous variables are not available for the model.\n",
    ":::"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. DeepAR model, during training the optimization signal comes from likelihood of observations, during inference a recurrent multi-step strategy is used to generate predictive distributions.](imgs_models/deepar.jpeg)"
   ]
  },
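  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recursive sampling idea behind the figure can be sketched in plain Python. This is a toy illustration under loud assumptions, not the DeepAR implementation: the function names `sample_trajectories` and `summarize` are hypothetical, and a fixed Gaussian rule (mean `0.9 * y`, standard deviation `1.0`) stands in for the RNN and its distribution parameters.\n",
    "\n",
    "```python\n",
    "import random\n",
    "import statistics\n",
    "\n",
    "def sample_trajectories(last_y, horizon, n_samples, seed=0):\n",
    "    # Toy stand-in for the network: theta_t = (0.9 * y_t, 1.0).\n",
    "    # A real model would compute theta_t from an RNN hidden state.\n",
    "    rng = random.Random(seed)\n",
    "    trajectories = []\n",
    "    for _ in range(n_samples):\n",
    "        y, path = last_y, []\n",
    "        for _ in range(horizon):\n",
    "            mu, sigma = 0.9 * y, 1.0   # assumed distribution parameters\n",
    "            y = rng.gauss(mu, sigma)   # sample y_{t+1} given theta_t\n",
    "            path.append(y)             # the sample becomes the next input\n",
    "        trajectories.append(path)\n",
    "    return trajectories\n",
    "\n",
    "def summarize(trajectories, probs=(0.05, 0.5, 0.95)):\n",
    "    # Per-step mean and empirical quantiles across trajectories\n",
    "    summary = []\n",
    "    for step in zip(*trajectories):\n",
    "        ordered = sorted(step)\n",
    "        quants = [ordered[min(int(p * len(ordered)), len(ordered) - 1)] for p in probs]\n",
    "        summary.append((statistics.fmean(step), quants))\n",
    "    return summary\n",
    "\n",
    "paths = sample_trajectories(last_y=10.0, horizon=3, n_samples=200)\n",
    "forecast = summarize(paths)\n",
    "```\n",
    "\n",
    "Each trajectory feeds its own sample back as the next input, so sampling noise accumulates along the horizon; the mean and quantiles are then read off across trajectories at each step."
   ]
  },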
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import logging\n",
    "import warnings\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from typing import Optional\n",
    "\n",
    "from neuralforecast.common._base_windows import BaseWindows\n",
    "from neuralforecast.losses.pytorch import DistributionLoss, MQLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"Multi-Layer Perceptron Decoder\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `in_features`: int, dimension of input.<br>\n",
    "    `out_features`: int, dimension of output.<br>\n",
    "    `hidden_size`: int, dimension of hidden layers.<br>\n",
    "    `num_layers`: int, number of hidden layers.<br>\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, hidden_size, hidden_layers):\n",
    "        super().__init__()\n",
    "\n",
    "        if hidden_layers == 0:\n",
    "            # Input layer\n",
    "            layers = [nn.Linear(in_features=in_features, out_features=out_features)]\n",
    "        else:\n",
    "            # Input layer\n",
    "            layers = [nn.Linear(in_features=in_features, out_features=hidden_size), nn.ReLU()]\n",
    "            # Hidden layers\n",
    "            for i in range(hidden_layers - 2):\n",
    "                layers += [nn.Linear(in_features=hidden_size, out_features=hidden_size), nn.ReLU()]\n",
    "            # Output layer\n",
    "            layers += [nn.Linear(in_features=hidden_size, out_features=out_features)]\n",
    "\n",
    "        # Store in layers as ModuleList\n",
    "        self.layers = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layers(x)"
   ]
  },
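  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged sketch of what a zero-hidden-layer decoder amounts to, the snippet below maps a batch of hidden states to the two parameters of a Normal distribution with a single linear layer and draws one sample per series. The softplus transform with a small offset is an assumption made here to keep the scale positive; it is not claimed to match the library's `domain_map` or `scale_decouple`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "hidden = torch.randn(4, 128)            # [B, lstm_hidden_size], stand-in for LSTM output\n",
    "head = nn.Linear(128, 2)                # linear decoder: hidden state -> (loc, raw scale)\n",
    "raw = head(hidden)                      # [B, 2]\n",
    "loc = raw[:, 0]\n",
    "scale = F.softplus(raw[:, 1]) + 1e-3    # assumed positivity transform\n",
    "sample = torch.distributions.Normal(loc, scale).sample()  # [B]\n",
    "```\n",
    "\n",
    "With `decoder_hidden_layers=0` the `Decoder` above reduces to one such `nn.Linear` layer; the output width would be given by the loss rather than fixed at 2."
   ]
  },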
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class DeepAR(BaseWindows):\n",
    "    \"\"\" DeepAR\n",
    "\n",
    "    The DeepAR model produces probabilistic forecasts based on an autoregressive recurrent neural network optimized on panel data using cross-learning. DeepAR obtains its forecast distribution uses a Markov Chain Monte Carlo sampler with the following conditional probability:\n",
    "    $$\\mathbb{P}(\\mathbf{y}_{[t+1:t+H]}|\\;\\mathbf{y}_{[:t]},\\; \\mathbf{x}^{(f)}_{[:t+H]},\\; \\mathbf{x}^{(s)})$$\n",
    "\n",
    "    where $\\mathbf{x}^{(s)}$ are static exogenous inputs, $\\mathbf{x}^{(f)}_{[:t+H]}$ are future exogenous available at the time of the prediction.\n",
    "    The predictions are obtained by transforming the hidden states $\\mathbf{h}_{t}$ into predictive distribution parameters $\\theta_{t}$, and then generating samples $\\mathbf{\\hat{y}}_{[t+1:t+H]}$ through Monte Carlo sampling trajectories.\n",
    "\n",
    "    \\begin{align}\n",
    "    \\mathbf{h}_{t} &= \\textrm{RNN}([\\mathbf{y}_{t},\\mathbf{x}^{(f)}_{t+1},\\mathbf{x}^{(s)}], \\mathbf{h}_{t-1})\\\\\n",
    "    \\mathbf{\\theta}_{t}&=\\textrm{Linear}(\\mathbf{h}_{t}) \\\\\n",
    "    \\hat{y}_{t+1}&=\\textrm{sample}(\\;\\mathrm{P}(y_{t+1}\\;|\\;\\mathbf{\\theta}_{t})\\;)\n",
    "    \\end{align}\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `lstm_n_layers`: int=2, number of LSTM layers.<br>\n",
    "    `lstm_hidden_size`: int=128, LSTM hidden size.<br>\n",
    "    `lstm_dropout`: float=0.1, LSTM dropout.<br>\n",
    "    `decoder_hidden_layers`: int=0, number of decoder MLP hidden layers. Default: 0 for linear layer. <br>\n",
    "    `decoder_hidden_size`: int=0, decoder MLP hidden size. Default: 0 for linear layer.<br>\n",
    "    `trajectory_samples`: int=100, number of Monte Carlo trajectories during inference.<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns.<br>\n",
    "    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=1000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.<br>\n",
    "    `inference_windows_batch_size`: int=-1, number of windows to sample in each inference batch, -1 uses all.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>    \n",
    "\n",
    "    **References**<br>\n",
    "    - [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020). \"DeepAR: Probabilistic forecasting with autoregressive recurrent networks\". International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207019301888)<br>\n",
    "    - [Alexander Alexandrov et. al (2020). \"GluonTS: Probabilistic and Neural Time Series Modeling in Python\". Journal of Machine Learning Research.](https://www.jmlr.org/papers/v21/19-820.html)<br>\n",
    "\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = 'windows'\n",
    "    \n",
    "    def __init__(self,\n",
    "                 h,\n",
    "                 input_size: int = -1,\n",
    "                 lstm_n_layers: int = 2,\n",
    "                 lstm_hidden_size: int = 128,\n",
    "                 lstm_dropout: float = 0.1,\n",
    "                 decoder_hidden_layers: int = 0,\n",
    "                 decoder_hidden_size: int = 0,\n",
    "                 trajectory_samples: int = 100,\n",
    "                 futr_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 stat_exog_list = None,\n",
    "                 exclude_insample_y = False,\n",
    "                 loss = DistributionLoss(distribution='StudentT', level=[80, 90], return_params=False),\n",
    "                 valid_loss = MQLoss(level=[80, 90]),\n",
    "                 max_steps: int = 1000,\n",
    "                 learning_rate: float = 1e-3,\n",
    "                 num_lr_decays: int = 3,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size: int = 1024,\n",
    "                 inference_windows_batch_size: int = -1,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'identity',\n",
    "                 random_seed: int = 1,\n",
    "                 num_workers_loader = 0,\n",
    "                 drop_last_loader = False,\n",
    "                 **trainer_kwargs):\n",
    "\n",
    "        # DeepAR does not support historic exogenous variables\n",
    "        if hist_exog_list is not None:\n",
    "            raise Exception('DeepAR does not support historic exogenous variables.')\n",
    "\n",
    "        if exclude_insample_y:\n",
    "            raise Exception('DeepAR has no possibility for excluding y.')\n",
    "        \n",
    "        if not loss.is_distribution_output:\n",
    "            raise Exception('DeepAR only supports distributional outputs.')\n",
    "        \n",
    "        if str(type(valid_loss)) not in [\"<class 'neuralforecast.losses.pytorch.MQLoss'>\"]:\n",
    "            raise Exception('DeepAR only supports MQLoss as validation loss.')\n",
    "\n",
    "        if loss.return_params:\n",
    "            raise Exception('DeepAR does not return distribution parameters due to Monte Carlo sampling.')\n",
    "    \n",
    "        # Inherit BaseWindows class\n",
    "        super(DeepAR, self).__init__(h=h,\n",
    "                                    input_size=input_size,\n",
    "                                    futr_exog_list=futr_exog_list,\n",
    "                                    hist_exog_list=hist_exog_list,\n",
    "                                    stat_exog_list=stat_exog_list,\n",
    "                                    exclude_insample_y = exclude_insample_y,\n",
    "                                    loss=loss,\n",
    "                                    valid_loss=valid_loss,\n",
    "                                    max_steps=max_steps,\n",
    "                                    learning_rate=learning_rate,\n",
    "                                    num_lr_decays=num_lr_decays,\n",
    "                                    early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                    val_check_steps=val_check_steps,\n",
    "                                    batch_size=batch_size,\n",
    "                                    windows_batch_size=windows_batch_size,\n",
    "                                    valid_batch_size=valid_batch_size,\n",
    "                                    inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                    start_padding_enabled=start_padding_enabled,\n",
    "                                    step_size=step_size,\n",
    "                                    scaler_type=scaler_type,\n",
    "                                    num_workers_loader=num_workers_loader,\n",
    "                                    drop_last_loader=drop_last_loader,\n",
    "                                    random_seed=random_seed,\n",
    "                                    **trainer_kwargs)\n",
    "\n",
    "        self.horizon_backup = self.h # Used because h=0 during training\n",
    "        self.trajectory_samples = trajectory_samples\n",
    "\n",
    "        # LSTM\n",
    "        self.encoder_n_layers = lstm_n_layers\n",
    "        self.encoder_hidden_size = lstm_hidden_size\n",
    "        self.encoder_dropout = lstm_dropout\n",
    "\n",
    "        self.futr_exog_size = len(self.futr_exog_list)\n",
    "        self.hist_exog_size = 0\n",
    "        self.stat_exog_size = len(self.stat_exog_list)\n",
    "        \n",
    "        # LSTM input size (1 for target variable y)\n",
    "        input_encoder = 1 + self.futr_exog_size + self.stat_exog_size\n",
    "\n",
    "        # Instantiate model\n",
    "        self.hist_encoder = nn.LSTM(input_size=input_encoder,\n",
    "                                    hidden_size=self.encoder_hidden_size,\n",
    "                                    num_layers=self.encoder_n_layers,\n",
    "                                    dropout=self.encoder_dropout,\n",
    "                                    batch_first=True)\n",
    "\n",
    "        # Decoder MLP\n",
    "        self.decoder = Decoder(in_features=lstm_hidden_size,\n",
    "                               out_features=self.loss.outputsize_multiplier,\n",
    "                               hidden_size=decoder_hidden_size,\n",
    "                               hidden_layers=decoder_hidden_layers)\n",
    "\n",
    "    # Override BaseWindows method\n",
    "    def training_step(self, batch, batch_idx):\n",
    "\n",
    "        # During training h=0  \n",
    "        self.h = 0\n",
    "        y_idx = batch['y_idx']\n",
    "\n",
    "        # Create and normalize windows [Ws, L, C]\n",
    "        windows = self._create_windows(batch, step='train')\n",
    "        original_insample_y = windows['temporal'][:, :, y_idx].clone() # windows: [B, L, Feature] -> [B, L]\n",
    "        original_insample_y = original_insample_y[:,1:] # Remove first (shift in DeepAr, cell at t outputs t+1)\n",
    "        windows = self._normalization(windows=windows, y_idx=y_idx)\n",
    "\n",
    "        # Parse windows\n",
    "        insample_y, insample_mask, _, _, _, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "\n",
    "        windows_batch = dict(insample_y=insample_y, # [Ws, L]\n",
    "                             insample_mask=insample_mask, # [Ws, L]\n",
    "                             futr_exog=futr_exog, # [Ws, L+H]\n",
    "                             hist_exog=None, # None\n",
    "                             stat_exog=stat_exog,\n",
    "                             y_idx=y_idx) # [Ws, 1]\n",
    "\n",
    "        # Model Predictions\n",
    "        output = self.train_forward(windows_batch)\n",
    "\n",
    "        if self.loss.is_distribution_output:\n",
    "            _, y_loc, y_scale = self._inv_normalization(y_hat=original_insample_y,\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "            outsample_y = original_insample_y\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "            mask = insample_mask[:,1:].clone() # Remove first (shift in DeepAr, cell at t outputs t+1)\n",
    "            loss = self.loss(y=outsample_y, distr_args=distr_args, mask=mask)\n",
    "        else:\n",
    "            raise Exception('DeepAR only supports distributional outputs.')\n",
    "\n",
    "        if torch.isnan(loss):\n",
    "            print('Model Parameters', self.hparams)\n",
    "            print('insample_y', torch.isnan(insample_y).sum())\n",
    "            print('outsample_y', torch.isnan(outsample_y).sum())\n",
    "            print('output', torch.isnan(output).sum())\n",
    "            raise Exception('Loss is NaN, training stopped.')\n",
    "\n",
    "        self.log('train_loss', loss, prog_bar=True, on_epoch=True)\n",
    "        self.train_trajectories.append((self.global_step, float(loss)))\n",
    "\n",
    "        self.h = self.horizon_backup # Restore horizon\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "\n",
    "        self.h == self.horizon_backup\n",
    "\n",
    "        if self.val_size == 0:\n",
    "            return np.nan\n",
    "\n",
    "        # TODO: Hack to compute number of windows\n",
    "        windows = self._create_windows(batch, step='val')\n",
    "        n_windows = len(windows['temporal'])\n",
    "        y_idx = batch['y_idx']\n",
    "\n",
    "        # Number of windows in batch\n",
    "        windows_batch_size = self.inference_windows_batch_size\n",
    "        if windows_batch_size < 0:\n",
    "            windows_batch_size = n_windows\n",
    "        n_batches = int(np.ceil(n_windows/windows_batch_size))\n",
    "\n",
    "        valid_losses = []\n",
    "        batch_sizes = []\n",
    "        for i in range(n_batches):\n",
    "            # Create and normalize windows [Ws, L+H, C]\n",
    "            w_idxs = np.arange(i*windows_batch_size, \n",
    "                               min((i+1)*windows_batch_size, n_windows))\n",
    "            windows = self._create_windows(batch, step='val', w_idxs=w_idxs)\n",
    "            original_outsample_y = torch.clone(windows['temporal'][:,-self.h:,0])\n",
    "            windows = self._normalization(windows=windows, y_idx=y_idx)\n",
    "\n",
    "            # Parse windows\n",
    "            insample_y, insample_mask, _, outsample_mask, \\\n",
    "                _, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "            windows_batch = dict(insample_y=insample_y,\n",
    "                        insample_mask=insample_mask,\n",
    "                        futr_exog=futr_exog,\n",
    "                        hist_exog=None,\n",
    "                        stat_exog=stat_exog,\n",
    "                        temporal_cols=batch['temporal_cols'],\n",
    "                        y_idx=y_idx) \n",
    "            \n",
    "            # Model Predictions\n",
    "            output_batch = self(windows_batch)\n",
    "            # Monte Carlo already returns y_hat with mean and quantiles\n",
    "            output_batch = output_batch[:,:, 1:] # Remove mean\n",
    "            valid_loss_batch = self.valid_loss(y=original_outsample_y, y_hat=output_batch, mask=outsample_mask)\n",
    "            valid_losses.append(valid_loss_batch)\n",
    "            batch_sizes.append(len(output_batch))\n",
    "\n",
    "        valid_loss = torch.stack(valid_losses)\n",
    "        batch_sizes = torch.tensor(batch_sizes).to(valid_loss.device)\n",
    "        valid_loss = torch.sum(valid_loss * batch_sizes) \\\n",
    "                        / torch.sum(batch_sizes)\n",
    "\n",
    "        if torch.isnan(valid_loss):\n",
    "            raise Exception('Loss is NaN, training stopped.')\n",
    "\n",
    "        self.log('valid_loss', valid_loss, prog_bar=True, on_epoch=True)\n",
    "        self.validation_step_outputs.append(valid_loss)\n",
    "        return valid_loss\n",
    "\n",
    "    def predict_step(self, batch, batch_idx):\n",
    "\n",
    "        self.h == self.horizon_backup\n",
    "\n",
    "        # TODO: Hack to compute number of windows\n",
    "        windows = self._create_windows(batch, step='predict')\n",
    "        n_windows = len(windows['temporal'])\n",
    "        y_idx = batch['y_idx']\n",
    "\n",
    "        # Number of windows in batch\n",
    "        windows_batch_size = self.inference_windows_batch_size\n",
    "        if windows_batch_size < 0:\n",
    "            windows_batch_size = n_windows\n",
    "        n_batches = int(np.ceil(n_windows/windows_batch_size))\n",
    "\n",
    "        y_hats = []\n",
    "        for i in range(n_batches):\n",
    "            # Create and normalize windows [Ws, L+H, C]\n",
    "            w_idxs = np.arange(i*windows_batch_size, \n",
    "                    min((i+1)*windows_batch_size, n_windows))\n",
    "            windows = self._create_windows(batch, step='predict', w_idxs=w_idxs)\n",
    "            windows = self._normalization(windows=windows, y_idx=y_idx)\n",
    "\n",
    "            # Parse windows\n",
    "            insample_y, insample_mask, _, _, _, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "            windows_batch = dict(insample_y=insample_y, # [Ws, L]\n",
    "                                insample_mask=insample_mask, # [Ws, L]\n",
    "                                futr_exog=futr_exog, # [Ws, L+H]\n",
    "                                stat_exog=stat_exog,\n",
    "                                temporal_cols=batch['temporal_cols'],\n",
    "                                y_idx=y_idx)\n",
    "            \n",
    "            # Model Predictions\n",
    "            y_hat = self(windows_batch)\n",
    "            # Monte Carlo already returns y_hat with mean and quantiles\n",
    "            y_hats.append(y_hat)\n",
    "        y_hat = torch.cat(y_hats, dim=0)\n",
    "        return y_hat\n",
    "\n",
    "    def train_forward(self, windows_batch):\n",
    "\n",
    "        # Parse windows_batch\n",
    "        encoder_input = windows_batch['insample_y'][:,:, None] # <- [B,T,1]\n",
    "        futr_exog  = windows_batch['futr_exog']\n",
    "        stat_exog  = windows_batch['stat_exog']\n",
    "\n",
    "        #[B, input_size-1, X]\n",
    "        encoder_input = encoder_input[:,:-1,:] # Remove last (shift in DeepAr, cell at t outputs t+1)\n",
    "        _, input_size = encoder_input.shape[:2]\n",
    "        if self.futr_exog_size > 0:\n",
    "            # Shift futr_exog (t predicts t+1, last output is outside insample_y)\n",
    "            encoder_input = torch.cat((encoder_input, futr_exog[:,1:,:]), dim=2)\n",
    "        if self.stat_exog_size > 0:\n",
    "            stat_exog = stat_exog.unsqueeze(1).repeat(1, input_size, 1) # [B, S] -> [B, input_size-1, S]\n",
    "            encoder_input = torch.cat((encoder_input, stat_exog), dim=2)\n",
    "\n",
    "        # RNN forward\n",
    "        hidden_state, _ = self.hist_encoder(encoder_input) # [B, input_size-1, rnn_hidden_state]\n",
    "\n",
    "        # Decoder forward\n",
    "        output = self.decoder(hidden_state) # [B, input_size-1, output_size]\n",
    "        output = self.loss.domain_map(output)\n",
    "        return output\n",
    "    \n",
    "    def forward(self, windows_batch):\n",
    "\n",
    "        # Parse windows_batch\n",
    "        encoder_input = windows_batch['insample_y'][:,:, None] # <- [B,L,1]\n",
    "        futr_exog  = windows_batch['futr_exog'] # <- [B,L+H, n_f]\n",
    "        stat_exog  = windows_batch['stat_exog']\n",
    "        y_idx = windows_batch['y_idx']\n",
    "\n",
    "        #[B, seq_len, X]\n",
    "        batch_size, input_size = encoder_input.shape[:2]\n",
    "        if self.futr_exog_size > 0:\n",
    "            futr_exog_input_window = futr_exog[:,1:input_size+1,:] # Align y_t with futr_exog_t+1\n",
    "            encoder_input = torch.cat((encoder_input, futr_exog_input_window), dim=2)\n",
    "        if self.stat_exog_size > 0:\n",
    "            stat_exog_input_window = stat_exog.unsqueeze(1).repeat(1, input_size, 1) # [B, S] -> [B, input_size, S]\n",
    "            encoder_input = torch.cat((encoder_input, stat_exog_input_window), dim=2)\n",
    "\n",
    "        # Use input_size history to predict first h of the forecasting window\n",
    "        _, h_c_tuple = self.hist_encoder(encoder_input)\n",
    "        h_n = h_c_tuple[0] # [n_layers, B, lstm_hidden_state]\n",
    "        c_n = h_c_tuple[1] # [n_layers, B, lstm_hidden_state]\n",
    "\n",
    "        # Vectorizes trajectory samples in batch dimension [1]\n",
    "        h_n = torch.repeat_interleave(h_n, self.trajectory_samples, 1) # [n_layers, B*trajectory_samples, rnn_hidden_state]\n",
    "        c_n = torch.repeat_interleave(c_n, self.trajectory_samples, 1) # [n_layers, B*trajectory_samples, rnn_hidden_state]\n",
    "\n",
    "        # Scales for inverse normalization\n",
    "        y_scale = self.scaler.x_scale[:, 0, [y_idx]].squeeze(-1).to(encoder_input.device)\n",
    "        y_loc = self.scaler.x_shift[:, 0, [y_idx]].squeeze(-1).to(encoder_input.device)\n",
    "        y_scale = torch.repeat_interleave(y_scale, self.trajectory_samples, 0)\n",
    "        y_loc = torch.repeat_interleave(y_loc, self.trajectory_samples, 0)\n",
    "\n",
    "        # Recursive strategy prediction\n",
    "        quantiles = self.loss.quantiles.to(encoder_input.device)\n",
    "        y_hat = torch.zeros(batch_size, self.h, len(quantiles)+1).to(encoder_input.device)\n",
    "        for tau in range(self.h):\n",
    "            # Decoder forward\n",
    "            last_layer_h = h_n[-1] # [B*trajectory_samples, lstm_hidden_state]\n",
    "            output = self.decoder(last_layer_h) \n",
    "            output = self.loss.domain_map(output)\n",
    "\n",
    "            # Inverse normalization\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "            # Add horizon (1) dimension\n",
    "            distr_args = list(distr_args)\n",
    "            for i in range(len(distr_args)):\n",
    "                distr_args[i] = distr_args[i].unsqueeze(-1)\n",
    "            distr_args = tuple(distr_args)\n",
    "            samples_tau, _, _ = self.loss.sample(distr_args=distr_args, num_samples=1)\n",
    "            samples_tau = samples_tau.reshape(batch_size, self.trajectory_samples)\n",
    "            sample_mean = torch.mean(samples_tau, dim=-1).to(encoder_input.device)\n",
    "            quants = torch.quantile(input=samples_tau, \n",
    "                                    q=quantiles, dim=-1).to(encoder_input.device)\n",
    "            y_hat[:,tau,0] = sample_mean\n",
    "            y_hat[:,tau,1:] = quants.permute((1,0)) # [Q, B] -> [B, Q]\n",
    "            \n",
    "            # Stop if already in the last step (no need to predict next step)\n",
    "            if tau+1 == self.h:\n",
    "                continue\n",
    "            # Normalize to use as input\n",
    "            encoder_input = self.scaler.scaler(samples_tau.flatten(), y_loc, y_scale) # [B*n_samples]\n",
    "            encoder_input = encoder_input[:, None, None] # [B*n_samples, 1, 1]\n",
    "\n",
    "            # Update input\n",
    "            if self.futr_exog_size > 0:\n",
    "                futr_exog_tau = futr_exog[:,[input_size+tau+1],:] # [B, 1, n_f]\n",
    "                futr_exog_tau = torch.repeat_interleave(futr_exog_tau, self.trajectory_samples, 0) # [B*n_samples, 1, n_f]\n",
    "                encoder_input = torch.cat((encoder_input, futr_exog_tau), dim=2) # [B*n_samples, 1, 1+n_f]\n",
    "            if self.stat_exog_size > 0:\n",
    "                stat_exog_tau = torch.repeat_interleave(stat_exog, self.trajectory_samples, 0) # [B*n_samples, n_s]\n",
    "                encoder_input = torch.cat((encoder_input, stat_exog_tau[:,None,:]), dim=2) # [B*n_samples, 1, 1+n_f+n_s]\n",
    "            \n",
    "            _, h_c_tuple = self.hist_encoder(encoder_input, (h_n, c_n))\n",
    "            h_n = h_c_tuple[0] # [n_layers, B, rnn_hidden_state]\n",
    "            c_n = h_c_tuple[1] # [n_layers, B, rnn_hidden_state]\n",
    "\n",
    "        return y_hat"
   ]
  },
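  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The one-step shift used in `training_step` and `train_forward` (the cell at `t` outputs `t+1`) can be illustrated on plain lists. This is a minimal sketch with a hypothetical helper name, `align_one_step_ahead`, and made-up values; it only mirrors the slicing pattern, not the batching or normalization.\n",
    "\n",
    "```python\n",
    "def align_one_step_ahead(y, futr_exog):\n",
    "    # y: observed values y_1..y_L; futr_exog: exogenous values x_1..x_L\n",
    "    inputs = list(zip(y[:-1], futr_exog[1:]))  # (y_t, x_{t+1}) pairs fed to the RNN\n",
    "    targets = y[1:]                            # y_{t+1}, scored by the likelihood\n",
    "    return inputs, targets\n",
    "\n",
    "inputs, targets = align_one_step_ahead(\n",
    "    y=[1.0, 2.0, 3.0, 4.0], futr_exog=[10, 20, 30, 40]\n",
    ")\n",
    "```\n",
    "\n",
    "Here `inputs` holds three `(y_t, x_{t+1})` pairs and `targets` holds the three values that follow them, so a window of length 4 yields 3 training terms."
   ]
  },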
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(DeepAR, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(DeepAR.fit, name='DeepAR.fit', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(DeepAR.predict, name='DeepAR.predict', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss, GMM, PMM\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "#from neuralforecast.models import DeepAR\n",
    "from neuralforecast.losses.pytorch import DistributionLoss, HuberMQLoss\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic\n",
    "\n",
    "#AirPassengersPanel['y'] = AirPassengersPanel['y'] + 10\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[DeepAR(h=12,\n",
    "                   input_size=48,\n",
    "                   lstm_n_layers=3,\n",
    "                   trajectory_samples=100,\n",
    "                   loss=DistributionLoss(distribution='Normal', level=[80, 90], return_params=False),\n",
    "                   learning_rate=0.005,\n",
    "                   stat_exog_list=['airline1'],\n",
    "                   futr_exog_list=['trend'],\n",
    "                   max_steps=100,\n",
    "                   val_check_steps=10,\n",
    "                   early_stop_patience_steps=-1,\n",
    "                   scaler_type='standard',\n",
    "                   enable_progress_bar=True),\n",
    "    ],\n",
    "    freq='M'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "Y_hat_df = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "# Plot quantile predictions\n",
    "Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "#plt.plot(plot_df['ds'], plot_df['DeepAR'], c='purple', label='mean')\n",
    "plt.plot(plot_df['ds'], plot_df['DeepAR-median'], c='blue', label='median')\n",
    "plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                 y1=plot_df['DeepAR-lo-90'][-12:].values, \n",
    "                 y2=plot_df['DeepAR-hi-90'][-12:].values,\n",
    "                 alpha=0.4, label='level 90')\n",
    "plt.legend()\n",
    "plt.grid()\n",
    "plt.plot()"
   ]
  },
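  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `level=[80, 90]` argument and the `DeepAR-lo-90`, `DeepAR-median` and `DeepAR-hi-90` columns used in the plot are related through central prediction intervals. A minimal sketch of that mapping, assuming a level `l` corresponds to the quantiles `(100 - l) / 200` and `1 - (100 - l) / 200`; the helper names `level_to_quantiles` and `column_names` are hypothetical, and the column order is an assumption rather than documented library behavior:\n",
    "\n",
    "```python\n",
    "def level_to_quantiles(levels):\n",
    "    # A level of 90 is the central 90% interval: quantiles 0.05 and 0.95\n",
    "    lows = [round((100 - lv) / 200, 4) for lv in sorted(levels, reverse=True)]\n",
    "    highs = [round(1 - q, 4) for q in reversed(lows)]\n",
    "    return lows + [0.5] + highs\n",
    "\n",
    "def column_names(model, levels):\n",
    "    lo = [f'{model}-lo-{lv}' for lv in sorted(levels, reverse=True)]\n",
    "    hi = [f'{model}-hi-{lv}' for lv in sorted(levels)]\n",
    "    return lo + [f'{model}-median'] + hi\n",
    "\n",
    "quantiles = level_to_quantiles([80, 90])\n",
    "names = column_names('DeepAR', [80, 90])\n",
    "```\n",
    "\n",
    "Under these assumptions, the shaded band in the plot spans the 0.05 and 0.95 quantiles of the sampled trajectories."
   ]
  },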
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
